{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recurrent Neural Network Example\n",
    "\n",
    "Build a recurrent neural network (LSTM) with TensorFlow 2.0.\n",
    "\n",
    "- Author: Aymeric Damien\n",
    "- Project: https://github.com/aymericdamien/TensorFlow-Examples/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNN Overview\n",
    "\n",
    "<img src=\"http://colah.github.io/posts/2015-08-Understanding-LSTMs/img/RNN-unrolled.png\" alt=\"nn\" style=\"width: 600px;\"/>\n",
    "\n",
    "References:\n",
    "- [Long Short Term Memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf), Sepp Hochreiter & Jurgen Schmidhuber, Neural Computation 9(8): 1735-1780, 1997.\n",
    "\n",
    "## MNIST Dataset Overview\n",
    "\n",
    "This example uses the MNIST handwritten digits dataset. It contains 60,000 examples for training and 10,000 examples for testing. The digits have been size-normalized and centered in a fixed-size image (28x28 pixels). In this example, pixel values are scaled from [0, 255] to [0, 1], and each image is kept as a 28x28 array rather than flattened to 784 features, so that it can be fed to the network row by row.\n",
    "\n",
    "![MNIST Dataset](http://neuralnetworksanddeeplearning.com/images/mnist_100_digits.png)\n",
    "\n",
    "To classify images using a recurrent neural network, we treat each image as a sequence of rows, where every row of 28 pixels is one timestep. Because the MNIST image shape is 28*28px, each sample is a single sequence of 28 timesteps with 28 input features per timestep.\n",
    "\n",
    "More info: http://yann.lecun.com/exdb/mnist/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "\n",
    "# Import TensorFlow v2.\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import Model, layers\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MNIST dataset parameters.\n",
    "num_classes = 10 # total classes (0-9 digits).\n",
    "num_features = 784 # data features (img shape: 28*28).\n",
    "\n",
    "# Training Parameters\n",
    "learning_rate = 0.001\n",
    "training_steps = 1000\n",
    "batch_size = 32\n",
    "display_step = 100\n",
    "\n",
    "# Network Parameters\n",
    "# MNIST image shape is 28*28px; each image is read as one sequence of 28 rows (timesteps) of 28 pixels each.\n",
    "num_input = 28 # number of input features per timestep (pixels in one row).\n",
    "timesteps = 28 # number of timesteps (rows in one image).\n",
    "num_units = 32 # number of neurons for the LSTM layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare MNIST data.\n",
    "from tensorflow.keras.datasets import mnist\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "# Convert to float32.\n",
    "x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)\n",
    "# Reshape images to sequences of 28 timesteps with 28 features each (one image row per timestep).\n",
    "x_train, x_test = x_train.reshape([-1, timesteps, num_input]), x_test.reshape([-1, timesteps, num_input])\n",
    "# Normalize image values from [0, 255] to [0, 1].\n",
    "x_train, x_test = x_train / 255., x_test / 255."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use tf.data API to shuffle and batch data.\n",
    "train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create LSTM Model.\n",
    "class LSTM(Model):\n",
    "    # Set layers.\n",
    "    def __init__(self):\n",
    "        super(LSTM, self).__init__()\n",
    "        # RNN (LSTM) hidden layer; returns the output of the last timestep only.\n",
    "        self.lstm_layer = layers.LSTM(units=num_units)\n",
    "        # Fully connected output layer producing one logit per class.\n",
    "        self.out = layers.Dense(num_classes)\n",
    "\n",
    "    # Set forward pass.\n",
    "    def call(self, x, is_training=False):\n",
    "        # LSTM layer.\n",
    "        x = self.lstm_layer(x)\n",
    "        # Output layer (num_classes).\n",
    "        x = self.out(x)\n",
    "        if not is_training:\n",
    "            # The tf cross-entropy function expects raw logits (no softmax),\n",
    "            # so only apply softmax when not training.\n",
    "            x = tf.nn.softmax(x)\n",
    "        return x\n",
    "\n",
    "# Build LSTM model.\n",
    "lstm_net = LSTM()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-Entropy Loss.\n",
    "# Note that this will apply 'softmax' to the logits.\n",
    "def cross_entropy_loss(x, y):\n",
    "    # Convert labels to int 64 for tf cross-entropy function.\n",
    "    y = tf.cast(y, tf.int64)\n",
    "    # Apply softmax to logits and compute cross-entropy.\n",
    "    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)\n",
    "    # Average loss across the batch.\n",
    "    return tf.reduce_mean(loss)\n",
    "\n",
    "# Accuracy metric.\n",
    "def accuracy(y_pred, y_true):\n",
    "    # Predicted class is the index of highest score in prediction vector (i.e. argmax).\n",
    "    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))\n",
    "    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)\n",
    "\n",
    "# Adam optimizer.\n",
    "optimizer = tf.optimizers.Adam(learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimization process. \n",
    "def run_optimization(x, y):\n",
    "    # Wrap computation inside a GradientTape for automatic differentiation.\n",
    "    with tf.GradientTape() as g:\n",
    "        # Forward pass.\n",
    "        pred = lstm_net(x, is_training=True)\n",
    "        # Compute loss.\n",
    "        loss = cross_entropy_loss(pred, y)\n",
    "        \n",
    "    # Variables to update, i.e. trainable variables.\n",
    "    trainable_variables = lstm_net.trainable_variables\n",
    "\n",
    "    # Compute gradients.\n",
    "    gradients = g.gradient(loss, trainable_variables)\n",
    "    \n",
    "    # Update weights following gradients.\n",
    "    optimizer.apply_gradients(zip(gradients, trainable_variables))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step: 100, loss: 1.663173, accuracy: 0.531250\n",
      "step: 200, loss: 1.034144, accuracy: 0.750000\n",
      "step: 300, loss: 0.775579, accuracy: 0.781250\n",
      "step: 400, loss: 0.840327, accuracy: 0.781250\n",
      "step: 500, loss: 0.344379, accuracy: 0.937500\n",
      "step: 600, loss: 0.884484, accuracy: 0.718750\n",
      "step: 700, loss: 0.569674, accuracy: 0.875000\n",
      "step: 800, loss: 0.401931, accuracy: 0.906250\n",
      "step: 900, loss: 0.530193, accuracy: 0.812500\n",
      "step: 1000, loss: 0.265871, accuracy: 0.968750\n"
     ]
    }
   ],
   "source": [
    "# Run training for the given number of steps.\n",
    "for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):\n",
    "    # Run the optimization to update W and b values.\n",
    "    run_optimization(batch_x, batch_y)\n",
    "    \n",
    "    if step % display_step == 0:\n",
    "        pred = lstm_net(batch_x, is_training=True)\n",
    "        loss = cross_entropy_loss(pred, batch_y)\n",
    "        acc = accuracy(pred, batch_y)\n",
    "        print(\"step: %i, loss: %f, accuracy: %f\" % (step, loss, acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
